[rllib] Refactor rllib to have a common sample collection pathway #2149
Conversation
Test FAILed.
observation_filter (str): Name of observation filter to use.
registry (tune.Registry): Tune object registry. Pass in the value
    from tune.registry.get_registry() if you're having trouble
    resolving objects registered in tune.
If I were only using RLlib, I wouldn't know what Tune is, and this wouldn't be that informative. I think we should fix up this documentation so that it describes the functionality of the registry (and refer to Tune for more information)
Updated
    resolving objects registered in tune.
env_config (dict): Config to pass to the env creator.
model_config (dict): Config to use when creating the policy model.
policy_config (dict): Config to pass to the policy.
Given that this is one of the core primitives that people would use, it makes sense to include examples for usage.
Done
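For context, here is a purely hypothetical sketch of the kind of usage example being requested. The module path, constructor keywords, and the sample() call are guesses pieced together from the docstring fields quoted above, not the merged API:

```python
# Hypothetical sketch only -- the import path, constructor keywords, and
# methods are assumptions for illustration, not the final API.
import gym

from ray.rllib.utils.common_policy_evaluator import CommonPolicyEvaluator  # assumed path


def env_creator(env_config):
    return gym.make("CartPole-v0")


evaluator = CommonPolicyEvaluator(      # assumed constructor, mirroring the docstring fields
    env_creator=env_creator,
    env_config={},
    model_config={},
    policy_config={},
    observation_filter="MeanStdFilter")

batch = evaluator.sample()              # assumed: collect one batch of experiences
```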
    loss_inputs=self.loss_in, is_training=self.is_training,
    state_inputs=self.state_in, state_outputs=self.state_out)

# TODO(ekl) move session creation and init to CommonPolicyEvaluator
isn't session creation already in CommonPolicyEvaluator?
Fixed
# TODO(rliaw): Can consider exposing these parameters
self.sess = tf.Session(graph=self.g, config=tf.ConfigProto(
    intra_op_parallelism_threads=1, inter_op_parallelism_threads=2,
Can you leave a TODO somewhere to make sure A3C creates a session with these parameters? It affects performance quite a bit.
Fixed
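For reference, a minimal standalone sketch of the session construction under discussion, using the same thread limits as the quoted diff (whether A3C should adopt these exact values is the open TODO):

```python
import tensorflow as tf

# Cap TF's per-op and cross-op thread pools. With many sampler processes on
# one machine, the default thread pools oversubscribe the CPU and noticeably
# hurt sampling throughput.
config = tf.ConfigProto(
    intra_op_parallelism_threads=1,
    inter_op_parallelism_threads=2)
sess = tf.Session(config=config)
```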
python/ray/rllib/bc/policy.py
Outdated
if self.summarize:
    bs = tf.to_float(tf.shape(self.x)[0])
    tf.summary.scalar("model/policy_loss", self.pi_loss / bs)
    tf.summary.scalar("model/policy_graph", self.pi_loss / bs)
this doesn't make that much sense
Oops fixed
def extra_apply_grad_fetches(self):
    return {}  # e.g., batch norm updates

def optimizer(self):
this creates a new Optimizer every time?
It's only called once.
Would it make sense to make it private (or idempotent)?
It's not stateless, and if public, it will show up in autocomplete tools (IPython, Jupyter, etc.) and cause headaches. People are already using rllib in notebook settings, and presumably a lot more will after this refactor.
This is just not an issue. We are calling this ourselves, not the user, so it's impossible for them to screw it up.
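For what it's worth, if the method were ever made idempotent as suggested, one simple pattern is to cache the optimizer on first access. A sketch of that pattern (not the PR's actual code; the PR simply calls optimizer() once internally):

```python
import tensorflow as tf


class PolicyGraphSketch(object):
    def optimizer(self):
        # Build the TF optimizer lazily and reuse it on later calls, so a
        # stray call from a notebook or autocomplete can't create duplicates.
        if not hasattr(self, "_optimizer"):
            self._optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)
        return self._optimizer
```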
feed_dict = self.extra_compute_action_feed_dict()
feed_dict[self._obs_input] = obs_batch
feed_dict[self._is_training] = is_training
for ph, value in zip(self._state_inputs, state_batches):
is this guaranteed to be ordered correctly?
Yeah, it's a list
    DQNEvaluator)
    num_gpus=self.config["num_gpus_per_worker"])
self.remote_evaluators = [
    remote_cls.remote(
side thought: it might even be cleaner if
remote_cls = CommonPolicyEvaluator.as_remote( ... )
remote_evaluators = [remote_cls(args) for i in range(num_workers)]
where remote_cls hides the ray cls.remote functionality
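A rough sketch of what that wrapper could look like, assuming a classmethod on the evaluator; the resource arguments and the simplified constructor are placeholders:

```python
import ray


class CommonPolicyEvaluator(object):
    def __init__(self, env_creator):
        self.env_creator = env_creator  # simplified constructor for the sketch

    @classmethod
    def as_remote(cls, num_cpus=1, num_gpus=0):
        # Wrap the class as a Ray actor and hide cls.remote() behind a plain
        # callable, per the suggestion above.
        remote_cls = ray.remote(num_cpus=num_cpus, num_gpus=num_gpus)(cls)
        return lambda *args, **kwargs: remote_cls.remote(*args, **kwargs)


# Usage (hypothetical):
#   remote_cls = CommonPolicyEvaluator.as_remote(num_gpus=0)
#   remote_evaluators = [remote_cls(env_creator) for _ in range(num_workers)]
```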
richardliaw
left a comment
Left some questions. Do we have a list or issue that all the refactoring is centered around?
General list of later todos:
- PyTorch Policy Graph
- Moving PPO onto the common evaluator
- consider a better way of managing exploration
- managing filters?
Other nit:
- Perhaps consider exposing something instead of LocalSyncReplay - something that puts the for loop of evaluation up front, and then think about the process of going from single thread to multi-process/multi-machine and making that process easy to do.
    return func(self)

def for_policy(self, func):
for_policy naming is a bit odd, but we can revisit this..
preprocessor_pref="rllib",
sample_async=False,
compress_observations=False,
consumer_buffer_size=0,
someone somewhere is going to need to explain what a "consumer" is to the user
Just removed for now.
episode_len_mean=mean_100ep_length,
episodes_total=num_episodes,
timesteps_this_iter=self.global_timestep - start_timestep,
exp_vals = [self.exploration0.value(self.global_timestep)]
I wonder if it makes sense to have the evaluator manage exploration.
This is fine to do in a followup discussion...
Hmm, if we expose some "global stats" object then it could.
from ray.rllib.utils.filter import get_filter, MeanStdFilter
from ray.rllib.utils.process_rollout import process_rollout
from ray.rllib.ppo.loss import ProximalPolicyLoss
from ray.rllib.ppo.loss import ProximalPolicyGraph
so this will eventually be moved onto CommonPolicyEvaluator?
Yeah, we should do that.
env: PongDeterministic-v4
run: A3C
config:
    num_workers: 16
this is for tuned examples right? i.e., examples where our configurations are supposed to be SOTA?
Fixed
    Box(0.0, 1.0, (5,), dtype=np.float32)]),
}

# (alg, action_space, obs_space)
why remove?
I'm just throwing unsupported now
a2 = get_mean_action(alg2, obs)
print("Checking computed actions", alg1, obs, a1, a2)
assert abs(a1 - a2) < .1, (a1, a2)
if abs(a1 - a2) > .1:
np.allclose is probably the better thing to use
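Roughly, the suggested check would look like this (the atol mirrors the existing 0.1 tolerance; the values are stand-ins):

```python
import numpy as np

a1, a2 = 0.52, 0.49  # stand-ins for the two computed mean actions
# Element-wise comparison with an explicit tolerance, instead of the
# hand-rolled abs() check from the quoted test.
assert np.allclose(a1, a2, atol=0.1), (a1, a2)
```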
self.config = config

# Technically not needed when not remote
self.obs_filter = get_filter(
are these functionalities completely supported in the refactoring (i.e., saving/restoring)? if not, we should probably leave a couple of notes/warnings
They should be
You can always copy-paste the code and run it directly, right? I don't think policy optimizers are required unless you're actually putting your algorithm into rllib. I kind of imagine the process as follows:
You could also imagine an even lower-level step where you use VectorEnv directly.
Test PASSed.
Test FAILed.
@richardliaw this is ready for review
python/ray/rllib/a3c/a3c.py
Outdated
self.local_evaluator.restore(extra_data["local_state"])

def compute_action(self, observation):
def compute_action(self, observation, state=[]):
better to avoid using mutable objects as default values, perhaps
state=None and state = [] if state is None else state
Done
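Spelled out, the suggested pattern looks like this (placeholder class and body):

```python
class AgentSketch(object):
    def compute_action(self, observation, state=None):
        # Avoid a shared mutable default: create a fresh list on each call.
        state = [] if state is None else state
        return observation, state  # placeholder body
```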
Test PASSed.
python/ray/rllib/dqn/dqn.py
Outdated
def compute_action(self, observation):
    return self.local_evaluator.dqn_graph.act(
        self.local_evaluator.sess, np.array(observation)[None], 0.0)[0]
def compute_action(self, observation, state=[]):
same comment here about default arguments
Fixed
remote_evaluators = [
    remote_cls.remote(*evaluator_args)
    for _ in range(num_workers)]
if type(evaluator_args) is list:
isinstance(evaluator_args, list)
python/ray/rllib/pg/pg.py
Outdated
def compute_action(self, obs):
    action, info = self.optimizer.local_evaluator.policy.compute(obs)
    return action
def compute_action(self, observation, state=[]):
mutable default arg
Fixed.
Test PASSed.
richardliaw
left a comment
OK last comments - will merge by tonight after addressed.
import tensorflow as tf
import gym
from ray.rllib.utils.error import UnsupportedSpaceException
nit: space between ray imports and non-ray
Done
@@ -6,76 +6,7 @@
import threading
from collections import namedtuple
import numpy as np
space between ray and non-ray imports?
Done
rewards=reward,
dones=terminal,
features=last_features,
new_obs=observation,
IDK where to leave this note, but we're actually doubling the number of states we need to send here (observation and last_observation). In a later optimization, we should consider addressing this (I can put this on TODOs)
We've been doing this all along -- but yeah, could optimize later.
actions (np.ndarray): batch of output actions, with shape like
    [BATCH_SIZE, ACTION_SHAPE].
state_outs (list): list of RNN state output batches, if any, with
    shape like [STATE_SIZE, BATCH_SIZE].
why is BATCH after STATE here?
This is so you have a small list of big lists and not a big list of small lists.
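A tiny numpy illustration of the layout choice; the batch size, number of state components, and hidden size are made up:

```python
import numpy as np

batch_size, num_state_components, hidden = 4, 2, 64  # e.g. LSTM (c, h) states

# Chosen layout ([STATE_SIZE, BATCH_SIZE]): a short outer list, one big
# batched array per state component, i.e. "a small list of big lists".
state_outs = [np.zeros((batch_size, hidden))
              for _ in range(num_state_components)]

# The alternative ([BATCH_SIZE, STATE_SIZE]) would be one small per-step
# state list per batch element, i.e. "a big list of small lists".
per_step = [[np.zeros(hidden) for _ in range(num_state_components)]
            for _ in range(batch_size)]
```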
"""Restores all local state.

Arguments:
    state (obj): Serialized local state."""
nit: inconsistent quote placement
Fixed
python/ray/rllib/ppo/loss.py
Outdated
    1 + config["clip_param"]) * advantages
self.surr = tf.minimum(self.surr1, self.surr2)
self.mean_policy_loss = tf.reduce_mean(-self.surr)
self.mean_policy_graph = tf.reduce_mean(-self.surr)
This naming change doesn't make sense.
Fixed
Updated
Test PASSed.
Test failures look unrelated.
What do these changes do?
Currently RLlib algorithms have disparate sample collection pathways. This makes supporting common functionality such as LSTMs, env vectorization, batch norm, and multi-agent setups hard to do in a generic way.
This PR adds a CommonPolicyEvaluator class, which is responsible for routing observations to Policy and TFPolicy instances. In the multi-agent case, this may involve batching and routing observations to several local policies. It will also handle the vectorized env case.
Related issue number
#2053
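As a rough illustration of the routing idea described in the PR summary above (not the PR's actual code), a toy single-process sketch that groups observations by policy, does one batched forward pass per policy, and fans the actions back out:

```python
from collections import defaultdict


class RoutingSketch(object):
    """Toy stand-in for the evaluator's multi-agent routing step."""

    def __init__(self, policies, policy_mapping_fn):
        self.policies = policies                    # {policy_id: policy-like object}
        self.policy_mapping_fn = policy_mapping_fn  # agent_id -> policy_id

    def compute_actions(self, obs_dict):
        # Batch observations per policy so each policy runs a single forward pass.
        grouped = defaultdict(list)
        for agent_id, obs in obs_dict.items():
            grouped[self.policy_mapping_fn(agent_id)].append((agent_id, obs))

        actions = {}
        for policy_id, items in grouped.items():
            agent_ids, obs_batch = zip(*items)
            action_batch = self.policies[policy_id].compute_actions(list(obs_batch))
            actions.update(zip(agent_ids, action_batch))
        return actions
```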